##
## Licensed to the .NET Foundation under one or more agreements.
## The .NET Foundation licenses this file to you under the MIT license.
##

import os
from datetime import datetime

from utils import next_power_of_2
from bitonic_isa import BitonicISA


# Size in bytes of each supported C element type. Its keys are also what
# NEONBitonicISA.supported_types() reports.
native_size_map = dict(
    int32_t=4,
    uint32_t=4,
)

class NEONBitonicISA(BitonicISA):
    """Code generator for ARM NEON (128-bit) bitonic sorting networks.

    Emits C++ that specializes ``vxsort::smallsort::bitonic<T, NEON>`` for a
    single element type: a header (prologue, per-width sorters and mergers,
    entry points, epilogue) and a source file holding ``sort()`` and the
    constant lane tables the generated code loads.
    """

    def __init__(self, type):
        """Create a generator for the C element type ``type``.

        ``type`` is expected to be one of :meth:`supported_types`.
        NOTE: the parameter name shadows the builtin ``type``; it is kept so
        callers passing it by keyword keep working.
        """
        self.vector_size_in_bytes = 16

        self.type = type

        # Lanes per 128-bit register for every supported element type.
        self.bitonic_size_map = {
            t: self.vector_size_in_bytes // s for t, s in native_size_map.items()
        }

        self.bitonic_type_map = {
            "int32_t": "int32x4_t",
            "uint32_t": "uint32x4_t",
        }

        self.bitonic_func_suffix_type_map = {
            "int32_t": "s32",
            "uint32_t": "u32",
        }

    def max_bitonic_sort_vectors(self):
        """Largest number of vectors a generated sorter handles."""
        return 16

    def vector_size(self):
        """Number of ``self.type`` lanes in one NEON register."""
        return self.bitonic_size_map[self.type]

    def vector_type(self):
        """NEON vector type holding ``self.type`` lanes, e.g. ``int32x4_t``."""
        return self.bitonic_type_map[self.type]

    def vector_suffix(self):
        """Intrinsic suffix for ``self.type``, e.g. ``s32`` in ``vminq_s32``."""
        return self.bitonic_func_suffix_type_map[self.type]

    def _lane_bits(self):
        """Width in bits of one ``self.type`` lane."""
        return native_size_map[self.type] * 8

    def _mask_scalar_type(self):
        """Unsigned scalar type as wide as one lane, e.g. ``uint32_t``."""
        return f"uint{self._lane_bits()}_t"

    def _mask_type(self):
        """Unsigned NEON vector type used for lane masks, e.g. ``uint32x4_t``.

        ``vcltq_*`` returns, and ``vbslq_*`` expects, an unsigned mask even
        for signed data, so masks must not use ``vector_type()``.
        """
        return f"uint{self._lane_bits()}x{self.vector_size()}_t"

    def _mask_suffix(self):
        """Intrinsic suffix used to load mask vectors, e.g. ``u32``."""
        return f"u{self._lane_bits()}"

    def _vector_pair_type(self):
        """Two-register struct returned by ``vtrnq_*``, e.g. ``int32x4x2_t``."""
        return f"{self.vector_type()[:-2]}x2_t"

    @classmethod
    def supported_types(cls):
        """Element types this generator can emit code for."""
        return native_size_map.keys()

    def generate_param_list(self, start, numParams):
        """Return ``"dSS, ..."`` naming ``numParams`` vector variables from ``start``."""
        return ", ".join(f"d{p:02d}" for p in range(start, start + numParams))

    def generate_param_def_list(self, numParams):
        """Return a C++ parameter list of ``numParams`` vector references ``d01``.."""
        vt = self.vector_type()
        return ", ".join(f"{vt}& d{p:02d}" for p in range(1, numParams + 1))

    def generate_sort_vec(self, merger, left, right, temp):
        """Emit a lane-wise compare-exchange: min ends in ``left``, max in ``right``.

        ``temp`` names a scratch vector variable already declared by the
        surrounding generated function. ``merger`` only selects which operand
        is saved to ``temp``; both forms compute the same result.
        """
        sfx = self.vector_suffix()
        if merger:
            return f"""
        {temp} = {left};
        {left} = vminq_{sfx}({right}, {left});
        {right} = vmaxq_{sfx}({right}, {temp});"""
        return f"""
        {temp} = {right};
        {right} = vmaxq_{sfx}({left}, {right});
        {left} = vminq_{sfx}({left}, {temp});"""

    def autogenerated_blabber(self):
        """Return the license and "auto-generated, do not edit" banner."""
        # "%Y-%m-%d" rather than "%F": %F is a C-library extension that
        # Python's strftime does not guarantee on every platform.
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return f"""// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

/////////////////////////////////////////////////////////////////////////////
////
// This file was auto-generated by a tool at {stamp}
//
// It is recommended you DO NOT directly edit this file but instead edit
// the code-generator that generated this source file instead.
/////////////////////////////////////////////////////////////////////////////"""

    def generate_prologue(self, f):
        """Write the header preamble to ``f``.

        The preamble covers the banner, include guard, includes, and the
        opening of the ``bitonic<T, NEON>`` specialization with its constant
        table declarations.
        """
        t = self.type
        mt = self._mask_scalar_type()
        s = f"""{self.autogenerated_blabber()}

#ifndef BITONIC_SORT_NEON_{t.upper()}_H
#define BITONIC_SORT_NEON_{t.upper()}_H

#include "bitonic_sort.h"
#include <arm_neon.h>
#include <limits>

namespace vxsort {{
namespace smallsort {{


template<> struct bitonic<{t}, NEON> {{
    static const int N = {self.vector_size()};
    static constexpr {t} MAX = std::numeric_limits<{t}>::max();

    static const {t} idxArray[];
    static const {t} maxvArray[];
    // Lane masks are unsigned: vbslq_* takes an unsigned mask for every lane type.
    static const {mt} mask1Array[];
    static const {mt} mask2Array[];
    static const {mt} mask13Array[];
    static const {mt} mask23Array[];
public:
"""
        print(s, file=f)

    def generate_epilogue(self, f):
        """Close the struct and both namespaces, then end the include guard."""
        s = f"""
}};
}}
}}

#endif // BITONIC_SORT_NEON_{self.type.upper()}_H
"""
        print(s, file=f)

    def generate_1v_basic_sorters(self, f, ascending):
        """Write ``sort_01v_ascending`` or ``sort_01v_descending`` to ``f``.

        It sorts the four lanes of one vector in-register with the network
        (0,1)(2,3); (0,2)(1,3); (1,2).
        """
        g = self
        vt = self.vector_type()
        sfx = self.vector_suffix()
        mvt = self._mask_type()
        msfx = self._mask_suffix()

        suffix = "ascending" if ascending else "descending"
        vtrnq_args = "mn1, mx1" if ascending else "mx1, mn1"
        vcombine_arg1 = "mn2" if ascending else "mx2"
        vcombine_arg2 = "mx2" if ascending else "mn2"
        vextq_arg = "mx12" if ascending else "mn12"
        vbslq_arg = "mn12" if ascending else "mx12"

        s = f"""    static INLINE void sort_01v_{suffix}({g.generate_param_def_list(1)}) {{

        // Sort (0,1) and (2,3)
        {vt} b   = vrev64q_{sfx}(d01);
        {vt} mn1 = vminq_{sfx}(d01, b);
        {vt} mx1 = vmaxq_{sfx}(d01, b);
        {self._vector_pair_type()} t1 = vtrnq_{sfx}({vtrnq_args});
        d01 = t1.val[0];

        // Sort (0,2) and (1,3)
        {vt} sh2  = vextq_{sfx}(d01, d01, 2);
        {vt} mn2  = vminq_{sfx}(d01, sh2);
        {vt} mx2  = vmaxq_{sfx}(d01, sh2);
        d01 = vcombine_{sfx}(vget_low_{sfx}({vcombine_arg1}), vget_high_{sfx}({vcombine_arg2}));

        // Sort (1,2)
        {vt} sh1   = vextq_{sfx}(d01, d01, 1);
        {vt} mn12  = vminq_{sfx}(d01, sh1);
        {vt} mx12  = vmaxq_{sfx}(d01, sh1);
        {vt} rot = vextq_{sfx}({vextq_arg}, {vextq_arg}, 3);
        const {mvt} mask1 = vld1q_{msfx}(mask1Array);
        const {mvt} mask2 = vld1q_{msfx}(mask2Array);
        d01 = vbslq_{sfx}(mask1, {vbslq_arg}, d01);
        d01 = vbslq_{sfx}(mask2, rot, d01);
    }}\n"""
        print(s, file=f)

    def generate_1v_merge_sorters(self, f, ascending: bool):
        """Write ``sort_01v_merge_ascending`` or ``sort_01v_merge_descending`` to ``f``.

        The stages are a cross-half compare (0,3)(1,2), then (0,2)(1,3), then
        (0,1)(2,3). Each stage picks min/max lanes with bit-select masks.
        """
        g = self
        vt = self.vector_type()
        sfx = self.vector_suffix()
        mvt = self._mask_type()
        msfx = self._mask_suffix()

        suffix = "ascending" if ascending else "descending"
        vbslq1_arg2 = "hi" if ascending else "lo"
        vbslq1_arg3 = "lo" if ascending else "hi"
        vbslq2_arg2 = "mx" if ascending else "mn"
        vbslq2_arg3 = "mn" if ascending else "mx"

        s = f"""    static INLINE void sort_01v_merge_{suffix}({g.generate_param_def_list(1)}) {{
        const {mvt} mask13 = vld1q_{msfx}(mask13Array);
        const {mvt} mask23 = vld1q_{msfx}(mask23Array);

        // Cross half compare
        {vt} t  = vrev64q_{sfx}(d01);
        t  = vextq_{sfx}(t, t, 2);
        {vt} lo = vminq_{sfx}(d01, t);
        {vt} hi = vmaxq_{sfx}(d01, t);
        d01 = vbslq_{sfx}(mask23, {vbslq1_arg2}, {vbslq1_arg3});

        // Sort (0,2) and (1,3)
        {vt} sh = vextq_{sfx}(d01, d01, 2);
        {vt} mn = vminq_{sfx}(d01, sh);
        {vt} mx = vmaxq_{sfx}(d01, sh);
        d01 = vbslq_{sfx}(mask23, {vbslq2_arg2}, {vbslq2_arg3});

        // Sort (0,1) and (2,3)
        sh = vrev64q_{sfx}(d01);
        mn = vminq_{sfx}(d01, sh);
        mx = vmaxq_{sfx}(d01, sh);
        d01 = vbslq_{sfx}(mask13, {vbslq2_arg2}, {vbslq2_arg3});
    }}\n"""
        print(s, file=f)

    def generate_compounded_sorter(self, f, width, ascending, inline):
        """Write ``sort_{width}v_{dir}`` to ``f``.

        The generated function sorts the first ``w1`` vectors in ``dir`` order
        and the remaining ``w2`` vectors in the opposite order. It then
        compare-exchanges each second-half vector with its mirror in the first
        half and merges both halves. ``w1`` is half of
        ``next_power_of_2(width)``.
        """
        g = self

        w1 = int(next_power_of_2(width) / 2)
        w2 = int(width - w1)

        suffix = "ascending" if ascending else "descending"
        rev_suffix = "descending" if ascending else "ascending"
        inl = "INLINE" if inline else "NOINLINE"

        s = f"""    static {inl} void sort_{width:02d}v_{suffix}({g.generate_param_def_list(width)}) {{
        {self.vector_type()} tmp;

        sort_{w1:02d}v_{suffix}({g.generate_param_list(1, w1)});
        sort_{w2:02d}v_{rev_suffix}({g.generate_param_list(w1 + 1, w2)});"""
        print(s, file=f)

        # Vector r (second half) is paired with its mirror 2*w1 + 1 - r (first half).
        for r in range(w1 + 1, width + 1):
            x = 2 * w1 + 1 - r
            print(self.generate_sort_vec(False, f"d{x:02d}", f"d{r:02d}", "tmp"), file=f)

        s = f"""
        sort_{w1:02d}v_merge_{suffix}({g.generate_param_list(1, w1)});
        sort_{w2:02d}v_merge_{suffix}({g.generate_param_list(w1 + 1, w2)});"""
        print(s, file=f)
        print("    }\n", file=f)

    def generate_compounded_merger(self, f, width, ascending, inline):
        """Write ``sort_{width}v_merge_{dir}`` to ``f``.

        The generated function compare-exchanges each vector ``r`` in the
        second half with vector ``r - w1``, then merges both halves. ``w1`` is
        half of ``next_power_of_2(width)``.
        """
        g = self

        w1 = int(next_power_of_2(width) / 2)
        w2 = int(width - w1)

        suffix = "ascending" if ascending else "descending"
        inl = "INLINE" if inline else "NOINLINE"

        s = f"""    static {inl} void sort_{width:02d}v_merge_{suffix}({g.generate_param_def_list(width)}) {{
        {self.vector_type()} tmp;"""
        print(s, file=f)

        for r in range(w1 + 1, width + 1):
            x = r - w1
            print(self.generate_sort_vec(True, f"d{x:02d}", f"d{r:02d}", "tmp"), file=f)

        s = f"""
        sort_{w1:02d}v_merge_{suffix}({g.generate_param_list(1, w1)});
        sort_{w2:02d}v_merge_{suffix}({g.generate_param_list(w1 + 1, w2)});"""
        print(s, file=f)
        print("    }\n", file=f)

    def generate_entry_points(self, f):
        """Write ``sort_NNv_alt(ptr, remainder)`` for NN = 1..max_bitonic_sort_vectors().

        Each function loads NN vectors from ``ptr``. In the last vector, lanes
        whose index is not below ``remainder`` are replaced by MAX before
        sorting; those lanes are stored back from the original load. A
        ``remainder`` of 0 means the last vector is full.
        """
        t = self.type
        vt = self.vector_type()
        sfx = self.vector_suffix()
        mvt = self._mask_type()
        n = self.vector_size()

        for m in range(1, self.max_bitonic_sort_vectors() + 1):
            last = m - 1

            s = f"""
    static NOINLINE void sort_{m:02d}v_alt({t} *ptr, int remainder) {{
        const {vt} maxv = vld1q_{sfx}(maxvArray);
        const {vt} idx = vld1q_{sfx}(idxArray);
        const {mvt} mask = vcltq_{sfx}(idx, vdupq_n_{sfx}(remainder ? remainder : N));\n"""
            print(s, file=f)

            # ptr already has the element type, so no pointer cast is needed.
            for l in range(last):
                print(f"        {vt} d{l + 1:02d} = vld1q_{sfx}(ptr + {l * n});", file=f)

            s = f"""        {vt} d{last + 1:02d}_orig = vld1q_{sfx}(ptr + {last * n});
        {vt} d{last + 1:02d} = vbslq_{sfx}(mask, d{last + 1:02d}_orig, maxv);"""
            print(s, file=f)

            print(f"\n        sort_{m:02d}v_ascending({self.generate_param_list(1, m)});\n", file=f)

            for l in range(last):
                print(f"        vst1q_{sfx}(ptr + {l * n}, d{l + 1:02d});", file=f)

            s = f"""        vst1q_{sfx}(ptr + {last * n}, vbslq_{sfx}(mask, d{last + 1:02d}, d{last + 1:02d}_orig));"""
            print(s, file=f)

            print("    }\n", file=f)

    def generate_master_entry_point(self, f_header, f_src):
        """Declare ``sort(ptr, length)`` in the header and define it in the source.

        The source also receives the constant lane tables. ``sort`` dispatches
        on the number of (possibly partial) vectors covering ``length`` to the
        matching ``sort_NNv_alt``. A vector count of 0 or above
        ``max_bitonic_sort_vectors()`` matches no ``case``, so nothing is
        sorted.
        """
        basename = os.path.basename(f_header.name)
        s = f"""// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#include "common.h"
#include "{basename}"

using namespace vxsort;
"""
        print(s, file=f_src)

        t = self.type
        mt = self._mask_scalar_type()
        qual = f"vxsort::smallsort::bitonic<{t}, vector_machine::NEON >"

        print(f"    static void sort({t} *ptr, size_t length);", file=f_header)

        s = f"""void {qual}::sort({t} *ptr, size_t length) {{
    const auto fullvlength = length / N;
    const int remainder = (int) (length - fullvlength * N);
    const auto v = fullvlength + ((remainder > 0) ? 1 : 0);
    switch(v) {{"""
        print(s, file=f_src)

        for m in range(1, self.max_bitonic_sort_vectors() + 1):
            print(f"        case {m}: sort_{m:02d}v_alt(ptr, remainder); break;", file=f_src)
        print("    }", file=f_src)
        print("}", file=f_src)

        # Mask tables use the unsigned lane type: vld1q_u32 feeds vbslq_* directly,
        # and ~0u would be a narrowing error in a signed brace initializer.
        s = f"""
const {t} {qual}::idxArray[4]   = {{0, 1, 2, 3}};
const {t} {qual}::maxvArray[4]  = {{MAX, MAX, MAX, MAX}};
const {mt} {qual}::mask1Array[4] = {{0u, ~0u, 0u, 0u}};
const {mt} {qual}::mask2Array[4] = {{0u, 0u, ~0u, 0u}};
const {mt} {qual}::mask13Array[4] = {{0u, ~0u, 0u, ~0u}};
const {mt} {qual}::mask23Array[4] = {{0u, 0u, ~0u, ~0u}};
"""
        print(s, file=f_src)
